package com.mgface.metadata.design.udd;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author wanyuxiang
 * @version 1.0
 * @project mddesign
 * @create 2021-06-07 14:33
 **/
@Component
@Slf4j
public class MgfaceSelectParser implements MgfaceParser {

    /**
     * One aliased table reference taken from a FROM clause.
     * Filled from the two operands of an {@code AS} node: operand 0 is the table, operand 1 the alias.
     */
    @Data
    private static class TableClass {
        private String tableName; // table name (string form of the AS node's first operand)
        private String alias; // alias (string form of the AS node's second operand)
    }

    /**
     * One item of the select list, written as {@code tableAlias.field AS asAlias}.
     */
    @Data
    private static class SelectClass {
        private String tableAlias;// alias of the table the column belongs to
        // Selected column name. The spelling "Filed" (sic) is kept on purpose: Lombok derives
        // getTableFiled()/setTableFiled() from it and other methods of this class call those accessors.
        private String tableFiled;
        private String asAlias;// alias given to the selected column
    }

    /**
     * One table-qualified column ({@code alias.field}) referenced by the WHERE clause.
     */
    @Data
    private static class WhereClass {
        private String whereTableAlias;// table alias of the column used in the WHERE condition
        private String whereTableField;// column name used in the WHERE condition
    }

    /**
     * One table-qualified column ({@code alias.field}) referenced by a JOIN ... ON condition.
     * The field names start with an upper-case letter; the Lombok accessors used elsewhere in this
     * class are getOnTableAlias()/getOnTableField() and the matching setters, and the generated
     * toString() prints the field names as written here.
     */
    @Data
    private static class OnClass {
        private String OnTableAlias;// table alias of the column used in the ON condition
        private String OnTableField;// column name used in the ON condition
    }

    /**
     * Recursively walks a FROM-clause node, collecting every aliased table and every
     * table-qualified column referenced by a JOIN condition.
     *
     * <ul>
     *   <li>{@code JOIN} node: the condition's columns are appended to {@code listOnClass}, then the
     *       left and right inputs are visited in that order.</li>
     *   <li>{@code AS} node: operand 0 (table) and operand 1 (alias) are appended to {@code list}.</li>
     *   <li>Any other node kind (for example a table without an alias) is ignored, as before.</li>
     * </ul>
     *
     * @param tbl         FROM-clause node to inspect
     * @param list        receives one entry per aliased table, in visiting order
     * @param listOnClass receives one entry per qualified column found in JOIN conditions
     */
    private static void paserTableName(SqlNode tbl, List<TableClass> list, List<OnClass> listOnClass) {
        if (tbl.getKind() == SqlKind.JOIN) {
            SqlJoin sqlJoin = (SqlJoin) tbl;

            // The condition is null for comma / CROSS joins; the collector treats null as "nothing to add"
            // instead of failing with a NullPointerException.
            collectOnColumns(sqlJoin.getCondition(), listOnClass);

            paserTableName(sqlJoin.getLeft(), list, listOnClass);
            paserTableName(sqlJoin.getRight(), list, listOnClass);
        } else if (tbl.getKind() == SqlKind.AS) {
            SqlBasicCall aliasedTable = (SqlBasicCall) tbl;
            TableClass tc = new TableClass();
            tc.setTableName(aliasedTable.operands[0].toString());
            tc.setAlias(aliasedTable.operands[1].toString());
            list.add(tc);
        }
    }

    /**
     * Appends every table-qualified identifier ({@code alias.field}) found in a join condition,
     * in left-to-right operand order.
     *
     * <p>For the plain {@code a.x = b.y} form this yields exactly the two entries the previous
     * implementation produced. Compound conditions ({@code AND}/{@code OR}, function calls) are
     * descended into; literals, unqualified identifiers and other node types are skipped because they
     * cannot be attributed to a table alias.
     *
     * @param condition   join condition or one of its operands; may be {@code null}
     * @param listOnClass receives the columns found
     */
    private static void collectOnColumns(SqlNode condition, List<OnClass> listOnClass) {
        if (condition instanceof SqlIdentifier) {
            SqlIdentifier column = (SqlIdentifier) condition;
            if (column.names.size() < 2) {
                return;
            }
            OnClass onClass = new OnClass();
            onClass.setOnTableAlias(column.names.get(0));
            onClass.setOnTableField(column.names.get(1));
            listOnClass.add(onClass);
        } else if (condition instanceof SqlBasicCall) {
            for (SqlNode operand : ((SqlBasicCall) condition).operands) {
                collectOnColumns(operand, listOnClass);
            }
        }
    }

    /**
     * Extracts the select columns, aliased tables, JOIN-condition columns and WHERE columns
     * from a parsed SELECT statement.
     *
     * <p>The result map contains:
     * <ul>
     *   <li>{@code "select"} - one {@code SelectClass} per select item; every item must be written as
     *       {@code alias.field AS name}.</li>
     *   <li>{@code "table"} - one {@code TableClass} per aliased table in the FROM clause
     *       (tables without an alias are not collected).</li>
     *   <li>{@code "on"} - the columns referenced by JOIN conditions.</li>
     *   <li>{@code "where"} - the table-qualified columns referenced anywhere in the WHERE clause;
     *       the key is present only when the statement has a WHERE clause.</li>
     * </ul>
     *
     * @param sqlNode parsed statement; must be a {@link SqlSelect}
     * @param sql     original SQL text (not used by this implementation)
     * @return the metadata map described above
     * @throws NullPointerException     if {@code sqlNode} is {@code null}
     * @throws ClassCastException       if {@code sqlNode} is not a {@link SqlSelect}
     * @throws RuntimeException         if a select item has no {@code AS} alias
     * @throws IllegalArgumentException if a select item is not a table-qualified column
     */
    @Override
    public Map<String, List<?>> extractMeta(SqlNode sqlNode, String sql) throws Exception {
        SqlSelect sqlSelect = (SqlSelect) Objects.requireNonNull(sqlNode);
        Map<String, List<?>> resultMap = new HashMap<>();

        // Select list.
        List<SelectClass> selectClassList = new ArrayList<>();
        for (SqlNode item : sqlSelect.getSelectList().getList()) {
            selectClassList.add(toSelectClass(item));
        }
        log.info("select 解析数据:{}", selectClassList);
        resultMap.put("select", selectClassList);

        // FROM clause. paserTableName already handles both the single aliased table and the JOIN tree,
        // so the logic is not duplicated here. A statement without FROM simply yields empty lists.
        List<TableClass> tableClassList = new ArrayList<>();
        List<OnClass> onClassList = new ArrayList<>();
        SqlNode from = sqlSelect.getFrom();
        if (from != null) {
            paserTableName(from, tableClassList, onClassList);
        }
        log.info("table 解析数据:{}", tableClassList);
        resultMap.put("table", tableClassList);
        log.info("on 解析数据:{}", onClassList);
        resultMap.put("on", onClassList);

        // WHERE clause.
        SqlNode where = sqlSelect.getWhere();
        if (where != null) {
            List<WhereClass> whereClassList = new ArrayList<>();
            collectWhereColumns(where, whereClassList);
            log.info("where 解析数据:{}", whereClassList);
            resultMap.put("where", whereClassList);
        }
        return resultMap;
    }

    /**
     * Converts one select-list item of the form {@code alias.field AS name} into a {@code SelectClass}.
     */
    private static SelectClass toSelectClass(SqlNode item) {
        if (!SqlKind.AS.equals(item.getKind())) {
            throw new RuntimeException("select 查询条件必须使用as别名.");
        }
        SqlBasicCall aliased = (SqlBasicCall) item;
        SqlNode target = aliased.operands[0];
        // Fail with a readable message instead of a bare ClassCastException / IndexOutOfBoundsException
        // when the aliased expression is not "alias.field" (e.g. a function call or an unqualified column).
        if (!(target instanceof SqlIdentifier) || ((SqlIdentifier) target).names.size() < 2) {
            throw new IllegalArgumentException(
                    "select item must be a table-qualified column (alias.field): " + target);
        }
        SqlIdentifier column = (SqlIdentifier) target;
        SelectClass sc = new SelectClass();
        sc.setTableAlias(column.names.get(0));
        sc.setTableFiled(column.names.get(1));
        sc.setAsAlias(aliased.operands[1].toString());
        return sc;
    }

    /**
     * Appends every table-qualified identifier ({@code alias.field}) found in a WHERE expression,
     * in left-to-right operand order.
     *
     * <p>The expression tree is descended through every call node, so a single predicate, AND/OR
     * chains of any depth and columns on either side of a comparison are all covered. Literals,
     * unqualified identifiers and other node types are skipped because they cannot be attributed
     * to a table alias.
     *
     * @param node           WHERE expression or one of its operands; may be {@code null}
     * @param whereClassList receives the columns found
     */
    private static void collectWhereColumns(SqlNode node, List<WhereClass> whereClassList) {
        if (node instanceof SqlIdentifier) {
            SqlIdentifier column = (SqlIdentifier) node;
            if (column.names.size() < 2) {
                return;
            }
            WhereClass whereClass = new WhereClass();
            whereClass.setWhereTableAlias(column.names.get(0));
            whereClass.setWhereTableField(column.names.get(1));
            whereClassList.add(whereClass);
        } else if (node instanceof SqlBasicCall) {
            for (SqlNode operand : ((SqlBasicCall) node).operands) {
                collectWhereColumns(operand, whereClassList);
            }
        }
    }

    /**
     * Resolves the logical tables and columns collected by {@code extractMeta} against the metadata
     * tables and builds the logical-to-physical replacement maps.
     *
     * <p>Steps:
     * <ol>
     *   <li>For every entry under {@code "table"}, look up the enabled object id by table name. On
     *       success the table's alias is remembered with that id and the logical table name is mapped
     *       to the physical table {@code datas}.</li>
     *   <li>For every column under {@code "select"}, {@code "where"} and {@code "on"}, look up the
     *       enabled field number for the owning object and map {@code alias.field} to
     *       {@code alias.Value<fieldnum>}.</li>
     * </ol>
     *
     * <p>The lookup is best-effort, as before: a table or column that cannot be resolved (no row, or
     * the query failed) is logged and left out of the result rather than aborting the call. Any of
     * the four input keys may be absent and is then treated as empty.
     *
     * <p>The connection, statements and result sets are now closed after use; previously the
     * connection obtained from {@code ds} was never released.
     *
     * @param ds   data source giving access to the metadata tables
     * @param data metadata map as produced by {@code extractMeta}
     * @return a map with {@code "table"} (logical table name to physical table name) and
     *         {@code "fields"} (logical {@code alias.field} to physical {@code alias.Value<n>})
     * @throws Exception if a connection cannot be obtained or closed
     */
    @Override
    public Map<String, Map<String, String>> authVerified(DataSource ds, Map<String, List<?>> data) throws Exception {
        Map<String, String> objectIdByAlias = new HashMap<>();
        Map<String, String> physicalTables = new HashMap<>();
        Map<String, String> physicalFields = new HashMap<>();

        try (Connection conn = ds.getConnection()) {
            // 1. Resolve the object behind every table.
            for (Object entry : entriesOf(data, "table")) {
                TableClass tc = (TableClass) entry;
                try {
                    String objId = queryFirstString(conn, OBJECT_ID_SQL, "objid",
                            APP_ID, ORG_ID, tc.getTableName());
                    if (objId == null) {
                        log.warn("对象ID不能为null, table={}", tc.getTableName());
                        continue;
                    }
                    objectIdByAlias.put(tc.getAlias(), objId);
                    // Replace the logical table with the physical one.
                    physicalTables.put(tc.getTableName(), "datas");
                } catch (Exception e) {
                    log.warn("object lookup failed, table={}", tc.getTableName(), e);
                }
            }

            // 2. Resolve every column referenced by the select list, the WHERE clause and the ON conditions.
            for (Object entry : entriesOf(data, "select")) {
                SelectClass sc = (SelectClass) entry;
                mapField(conn, objectIdByAlias, sc.getTableAlias(), sc.getTableFiled(), physicalFields);
            }
            for (Object entry : entriesOf(data, "where")) {
                WhereClass wc = (WhereClass) entry;
                mapField(conn, objectIdByAlias, wc.getWhereTableAlias(), wc.getWhereTableField(), physicalFields);
            }
            for (Object entry : entriesOf(data, "on")) {
                OnClass oc = (OnClass) entry;
                mapField(conn, objectIdByAlias, oc.getOnTableAlias(), oc.getOnTableField(), physicalFields);
            }
        }

        Map<String, Map<String, String>> total = new HashMap<>();
        total.put("table", physicalTables);
        total.put("fields", physicalFields);
        return total;
    }

    // NOTE(review): AppID and OrgID were hard-coded literals in the original queries; presumably
    // placeholders until the caller's application/organisation context is available - confirm.
    private static final String APP_ID = "1";
    private static final String ORG_ID = "A00001";

    private static final String OBJECT_ID_SQL =
            "select a.ObjID as objid from objects as a where a.AppID = ? and a.OrgID = ? and a.ObjName = ? and a.IsEnabled = 'Y'";

    // NOTE(review): the table name "fileds" is reproduced from the original query; verify against the schema.
    private static final String FIELD_NUM_SQL =
            "select a.FieldNum as fieldnum from fileds as a where a.ObjID = ? and a.OrgID = ? and a.FieldName = ? and a.IsEnabled = 'Y'";

    /** Returns the list stored under {@code key}, or an empty list when the key is absent. */
    private static List<?> entriesOf(Map<String, List<?>> data, String key) {
        List<?> entries = data.get(key);
        if (entries == null) {
            return new ArrayList<>();
        }
        return entries;
    }

    /**
     * Looks up the field number of {@code alias.field} and, when found, records the mapping
     * {@code alias.field -> alias.Value<fieldnum>} in {@code physicalFields}. Failures are logged and skipped.
     */
    private static void mapField(Connection conn, Map<String, String> objectIdByAlias,
                                 String alias, String field, Map<String, String> physicalFields) {
        String logicalName = String.format("%s.%s", alias, field);
        String objId = objectIdByAlias.get(alias);
        if (objId == null) {
            // The owning table was not resolved in step 1, so no field row can match.
            log.warn("对象fieldnum不能为null, field={}", logicalName);
            return;
        }
        try {
            String fieldNum = queryFirstString(conn, FIELD_NUM_SQL, "fieldnum", objId, ORG_ID, field);
            if (fieldNum == null) {
                log.warn("对象fieldnum不能为null, field={}", logicalName);
                return;
            }
            physicalFields.put(logicalName, String.format("%s.Value%s", alias, fieldNum));
        } catch (Exception e) {
            log.warn("field lookup failed, field={}", logicalName, e);
        }
    }

    /**
     * Runs a parameterised query and returns {@code column} of the first row, or {@code null}
     * when the query returns no row. Statement and result set are closed before returning.
     */
    private static String queryFirstString(Connection conn, String sql, String column, String... params)
            throws SQLException {
        try (PreparedStatement stmt = conn.prepareStatement(sql)) {
            for (int i = 0; i < params.length; i++) {
                stmt.setString(i + 1, params[i]);
            }
            try (ResultSet rs = stmt.executeQuery()) {
                return rs.next() ? rs.getString(column) : null;
            }
        }
    }

    /**
     * Rewrites a logical SQL statement into its physical form by substituting the column names
     * under {@code "fields"} first and the table names under {@code "table"} second.
     *
     * <p>Each map is applied in a single pass that only matches whole tokens: a key is replaced only
     * when it is not directly preceded or followed by a letter, digit or underscore. This prevents a
     * key such as {@code a.name} from corrupting {@code a.nameExt} or {@code ba.name}, makes the result
     * independent of map iteration order, and keeps already-substituted text from being rewritten
     * again by another entry of the same map.
     *
     * @param sql  logical SQL text
     * @param data replacement maps as returned by {@code authVerified}; a missing map is treated as empty
     * @return the rewritten SQL
     */
    @Override
    public String physicalQuery(String sql, Map<String, Map<String, String>> data) throws Exception {
        String rewritten = replaceTokens(sql, data.get("fields"));
        rewritten = replaceTokens(rewritten, data.get("table"));
        log.info("最终形成的SQL语句:{}", rewritten);
        return rewritten;
    }

    /**
     * Replaces every whole-token occurrence of a key of {@code replacements} in {@code sql}
     * with the corresponding value, scanning the text once.
     */
    private static String replaceTokens(String sql, Map<String, String> replacements) {
        if (replacements == null || replacements.isEmpty()) {
            return sql;
        }

        // Longest key first so that, at any position, the most specific key wins.
        List<String> keys = new ArrayList<>();
        for (String key : replacements.keySet()) {
            if (key != null && !key.isEmpty()) {
                keys.add(key);
            }
        }
        if (keys.isEmpty()) {
            return sql;
        }
        keys.sort((left, right) -> Integer.compare(right.length(), left.length()));

        StringBuilder alternation = new StringBuilder();
        for (String key : keys) {
            if (alternation.length() > 0) {
                alternation.append('|');
            }
            alternation.append(Pattern.quote(key));
        }

        // UNICODE_CHARACTER_CLASS makes \w cover non-ASCII letters so the token guard also works
        // for identifiers that are not plain ASCII.
        Pattern tokenPattern = Pattern.compile(
                "(?<!\\w)(?:" + alternation + ")(?!\\w)", Pattern.UNICODE_CHARACTER_CLASS);
        Matcher matcher = tokenPattern.matcher(sql);
        StringBuffer result = new StringBuffer();
        while (matcher.find()) {
            // quoteReplacement keeps '$' and '\' in the value from being read as group references.
            matcher.appendReplacement(result, Matcher.quoteReplacement(replacements.get(matcher.group())));
        }
        matcher.appendTail(result);
        return result.toString();
    }
}
